#!/usr/bin/env python3
# -*- coding: utf-8 -*-
'''
# @Time : 2023/10/16 21:31
# @Author: from https://github.com/wenet-e2e/wenet/blob/main/tools/compute-wer.py
# @Modified By Shiyu He
# @University : Xinjiang University
'''

import os, re, sys
import json
import random

# Characters stripped by remove_punctuation() before scoring. The entries are
# joined into a single regex character class, so every character of every
# entry is removed; the multi-character '...' entry is therefore redundant
# with '.', but it is kept so existing indices into this list do not change.
puncts = [
    '!', ',', '?', '、', '。', '！', '，', '；', '？', '：', '「', '」', '『', '』',
    '《', '》', '.', ';', ':', '(', ')', '[', ']', '{', '}', '"', "'", '...',
    '—', '-', '/', '\\', '%', '￥', '$', '·', '`', '‘', '’', '“', '”', '~', '@', '#',
    '&', '*', '_', '+', '=', '|', '<', '>', '^', '【', '】',
    # Full-width forms common in Chinese text whose ASCII counterparts
    # ('(', ')', '~', '...') are already listed above.
    '（', '）', '…', '～',
    ]

def remove_punctuation(sentence: str):
    """Return ``sentence`` with every character listed in ``puncts`` deleted.

    All entries of ``puncts`` are concatenated, escaped, and wrapped in a
    single regex character class. If compiling or applying that pattern
    raises ``re.error``, the error is printed and ``sentence`` is returned
    unchanged.
    """
    char_class = re.escape(''.join(puncts))
    try:
        return re.sub(f'[{char_class}]', '', sentence)
    except re.error as err:
        print(f"Error in remove_punctuation: {err}")
        return sentence


class RepetitionEditDist:
    """Token-level edit-distance aligner used for speech-repetition scoring.

    Each call to :meth:`get_edit_dist` aligns a label (reference) token
    sequence against a recognised token sequence and returns per-utterance
    counts of correct tokens, substitutions, insertions and deletions.
    Per-token counts are also accumulated across calls in ``self.data``.

    Attributes:
        data: maps every non-empty token seen so far to a dict of cumulative
            'all', 'cor', 'sub', 'ins' and 'del' counts.
        space: dynamic-programming matrix from the most recent call; cell
            ``space[i][j]`` is ``{'dist': int, 'error': str}``, where ``error``
            is the last operation on the cheapest path to that cell.
        cost: cost of each operation ('cor', 'sub', 'del', 'ins').
    """

    def __init__(self):
        self.data = {}
        self.space = []
        self.cost = {'cor': 0, 'sub': 1, 'del': 1, 'ins': 1}

    def get_edit_dist(self, lab: list, rec: list):
        """Align ``rec`` against ``lab`` and count the edit operations.

        Args:
            lab: label (reference) tokens.
            rec: recognised tokens.
            Both may be any iterable of sized, hashable tokens; they are copied
            and never modified.

        Returns:
            A dict with keys 'all' (number of non-empty label tokens),
            'cor', 'sub', 'ins' and 'del'. Returns None if an error occurred;
            the error is printed.
        """
        try:
            # A leading '' sentinel stands for the empty prefix. It is added to
            # copies so the caller's lists are left untouched.
            lab = [''] + list(lab)
            rec = [''] + list(rec)
            self._init_space(len(lab), len(rec))
            self._register_tokens(lab)
            self._register_tokens(rec)
            self._fill_space(lab, rec)
            return self._trace_back(lab, rec)
        except Exception as e:
            print(f"Error in get_edit_dist: {e}")
            return None

    def _init_space(self, n_lab, n_rec):
        """Allocate a fresh n_lab x n_rec DP matrix with its first row/column set."""
        self.space = [[{'dist': 0, 'error': 'non'} for _ in range(n_rec)]
                      for _ in range(n_lab)]
        for i in range(n_lab):
            self.space[i][0] = {'dist': i, 'error': 'del'}
        for j in range(n_rec):
            self.space[0][j] = {'dist': j, 'error': 'ins'}
        # Cell (0, 0) marks the start of the alignment and stops the trace-back.
        self.space[0][0]['error'] = 'non'

    def _register_tokens(self, tokens):
        """Ensure every non-empty token has a zeroed counter entry in ``self.data``."""
        for token in tokens:
            if len(token) > 0 and token not in self.data:
                self.data[token] = {'all': 0, 'cor': 0, 'sub': 0, 'ins': 0, 'del': 0}

    def _fill_space(self, lab, rec):
        """Fill the DP matrix with minimum edit distances and their last operation."""
        cost = self.cost
        space = self.space
        for i in range(1, len(lab)):
            for j in range(1, len(rec)):
                # Candidates are considered in the order del, ins, cor/sub, and a
                # later one wins only if strictly cheaper; this fixes tie-breaking.
                min_dist = space[i - 1][j]['dist'] + cost['del']
                min_error = 'del'
                dist = space[i][j - 1]['dist'] + cost['ins']
                if dist < min_dist:
                    min_dist, min_error = dist, 'ins'
                if lab[i] == rec[j]:
                    dist, error = space[i - 1][j - 1]['dist'] + cost['cor'], 'cor'
                else:
                    dist, error = space[i - 1][j - 1]['dist'] + cost['sub'], 'sub'
                if dist < min_dist:
                    min_dist, min_error = dist, error
                space[i][j] = {'dist': min_dist, 'error': min_error}

    def _trace_back(self, lab, rec):
        """Walk the cheapest path back from the last cell and tally operations."""
        result = {'all': 0, 'cor': 0, 'sub': 0, 'ins': 0, 'del': 0}
        i = len(lab) - 1
        j = len(rec) - 1
        while True:
            error = self.space[i][j]['error']
            if error == 'non':  # starting point reached
                break
            if error in ('cor', 'sub', 'del'):
                token = lab[i]
                if len(token) > 0:  # the '' sentinel is never counted
                    self.data[token]['all'] += 1
                    self.data[token][error] += 1
                    result['all'] += 1
                    result[error] += 1
                i -= 1
                if error != 'del':  # cor/sub consume a recognised token too
                    j -= 1
            elif error == 'ins':
                token = rec[j]
                if len(token) > 0:
                    self.data[token]['ins'] += 1
                    result['ins'] += 1
                j -= 1
            else:
                # Raise instead of looping forever on a corrupted matrix.
                raise RuntimeError(
                    'this should not happen , i = {i} , j = {j} , error = {error}'
                    .format(i=i, j=j, error=error))
        return result

class SpeechRepetitionScore:
    """Turn edit-operation counts into a speech-repetition score.

    Args:
        Full: full-marks score.
        level: exam level (1, 2 or 3); selects the penalty weight returned by
            :meth:`level_to_weight`.
    """

    def __init__(self, Full=10, level=3):
        self.level = level
        self.Full = Full

    def level_to_weight(self):
        """Return the penalty weight for ``self.level``.

        Raises:
            ValueError: if ``self.level`` is not 1, 2 or 3.
        """
        if self.level == 1:
            return 1.0
        elif self.level == 2:
            return 0.8
        elif self.level == 3:
            return 0.7
        else:
            raise ValueError("等级(level)必须是1、2或3。")

    def computer_score(self, result: dict, ref_len: int = None):
        """Compute the repetition score from edit-operation counts.

        Each deletion and insertion costs 0.6 and each substitution 1.0, all
        multiplied by the level weight. The summed penalty is divided by the
        reference length and scaled to ``self.Full``. The score is
        ``Full - penalty``, floored at 0.

        Args:
            result: counts with keys 'all', 'sub', 'ins' and 'del', e.g. as
                returned by ``RepetitionEditDist.get_edit_dist``.
            ref_len: number of reference tokens to normalise by. Defaults to
                ``result['all']`` (every label token: cor + sub + del).

        Returns:
            The score; 0 if the reference length is 0; None if an error
            occurred (e.g. invalid level or malformed ``result``). Errors are
            printed.
        """
        try:
            level_weight = self.level_to_weight()

            remove_score_del = result['del'] * 0.6 * level_weight
            remove_score_sub = result['sub'] * 1 * level_weight
            remove_score_ins = result['ins'] * 0.6 * level_weight

            # Normalise by the reference length taken from the counts
            # themselves, not from a global that only exists in script mode.
            if ref_len is None:
                ref_len = result['all']
            if ref_len <= 0:
                return 0
            remove_score = (remove_score_del + remove_score_sub + remove_score_ins) / ref_len * self.Full

            if remove_score >= 4:
                # TODO: introduce semantic-similarity analysis for heavily
                # mismatched repetitions. Until then, fall through to the plain
                # penalty rather than leaving the score unassigned.
                pass
            return max(0, self.Full - remove_score)

        except Exception as e:
            print(f"Error in computer_score: {e}")
            return None


if __name__ == '__main__':
    try:
        LEVEL = 2
        FULL = 10  # full-marks score
        lab = "今天不是个好日子"  # label / reference text
        rec = "今天是个好日子"  # recognised text

        # Strip punctuation so it neither costs points nor counts as a match.
        lab_token = remove_punctuation(lab)
        rec_token = remove_punctuation(rec)

        if lab_token == rec_token:
            score = FULL
        else:
            edit_dist = RepetitionEditDist()
            editdist_result = edit_dist.get_edit_dist(list(lab_token), list(rec_token))
            # Build a scorer from the full-marks value and the exam level.
            scorer = SpeechRepetitionScore(FULL, LEVEL)
            score = scorer.computer_score(editdist_result)
        print("score:", score)

    except ValueError as ve:
        print(f"错误：{ve}")

    except Exception as e:
        print(f"发生异常：{e}")
